import re
import json_repair
import requests
import random
import json

def load_config(config_path):
    """Return the run configuration, filling any gaps with built-in defaults.

    When ``config_path`` does not exist, a notice is printed and the default
    configuration is returned. Otherwise the file is read as UTF-8 JSON and
    every default key it lacks is appended. Keys present in the file always
    win over the defaults.

    Args:
        config_path: Filesystem path of a JSON configuration file.

    Returns:
        dict: The (possibly merged) configuration.
    """
    import json
    import os

    # Built-in fallback values for every supported setting.
    defaults = {
        "cases": [1, 2, 3, 4, 5],
        "model": "gpt-4o",
        "api_url": "https://api.openai.com/v1/chat/completions",
        "api_key": "sk-your-api-key",
        "is_reasoning_model": False,
        "temperature": 0.7,
        "cost_input": 2.5e-6,
        "cost_output": 10e-6,
        "max_retries": 3,
        "bench_file": "PersonaEval_example_cases.csv",
        "output_dir": "results"
    }

    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        # Only fill in what the user did not specify; order of appended
        # keys follows the defaults table.
        for name, fallback in defaults.items():
            if name not in loaded:
                loaded[name] = fallback
        return loaded

    print(f"Config file {config_path} does not exist, using default configuration")
    return defaults

def parse_response(response, option1, option2, option3, option4):
    """Parse an LLM response to extract a probability distribution over four options.

    If the response contains triple-backtick fenced blocks, the content of the
    last one is parsed. An optional ``json`` language tag after the opening
    fence is dropped. Otherwise the whole response is parsed. Parsing goes
    through ``json_repair`` so slightly malformed JSON is still accepted.

    Args:
        response: Raw text returned by the model.
        option1, option2, option3, option4: Option labels that must all appear
            as keys of the parsed JSON object.

    Returns:
        dict: The parsed object with the four option values converted to
        ``float``. Any extra keys emitted by the model are left untouched.

    Raises:
        ValueError: If the payload is not a JSON object, if an option is
            missing, or if the four option probabilities do not sum to 1 when
            rounded to 5 decimal places.
        Exception: Any other error (e.g. a value ``float`` rejects) is
            re-raised unchanged. In every failure case the raw response is
            printed first.
    """
    options = (option1, option2, option3, option4)
    try:
        matches = re.findall(r"```(?:json)?\s*(.+?)\s*```", response, re.DOTALL | re.IGNORECASE)
        # The last fenced block is taken as the model's final answer.
        json_str = matches[-1] if matches else response

        parsed_obj = json_repair.loads(json_str)
        # json_repair may return a list or a string for non-object input. The
        # membership test below would then check elements or substrings.
        if not isinstance(parsed_obj, dict):
            raise ValueError(f"Expected a JSON object, got {type(parsed_obj).__name__}: {parsed_obj!r}")

        # Explicit checks instead of `assert`, which is stripped under -O.
        missing = [opt for opt in options if opt not in parsed_obj]
        if missing:
            raise ValueError(f"Missing options {missing} in {parsed_obj}")

        for opt in options:
            parsed_obj[opt] = float(parsed_obj[opt])

        # Sum the option values by key. The model may emit extra keys (e.g. a
        # rationale), so "the first four values" are not necessarily the options.
        total = sum(parsed_obj[opt] for opt in options)
        if round(total, 5) != 1:
            raise ValueError(f"Sum of probabilities is not 1: {parsed_obj}")
        return parsed_obj
    except Exception as e:
        print(f"Error parsing response {response}: {e}")
        raise

def call_api(prompt, config, option1, option2, option3, option4, gt):
    """Call the standard (non-streaming) chat-completions API and pick an answer.

    Args:
        prompt: Text sent as the single user message.
        config: Configuration dict. Reads ``model``, ``temperature``,
            ``api_key``, ``api_url``, ``cost_input`` and ``cost_output``.
        option1, option2, option3, option4: Option labels the model must
            assign probabilities to (see ``parse_response``).
        gt: Not used by this function. Kept for signature parity with
            ``call_api_stream``.

    Returns:
        tuple: ``(res, response_str, res_json, completion_tokens, cost)``.
        - ``res`` is the option with the highest probability; the first
          option wins on ties.
        - ``response_str`` is the raw model text.
        - ``res_json`` is the parsed distribution.
        - ``cost`` is computed from the per-token prices in ``config``.
        Token count and cost are 0 if the API reported no usage.

    Raises:
        Exception: If the body is not JSON, the API returns an ``error``
            field, or the HTTP status is not 200. Errors from
            ``parse_response`` propagate unchanged.
    """
    payload = {
        "model": config["model"],
        "messages": [{"role": "user", "content": prompt}],
        "temperature": config["temperature"],
    }
    headers = {
        "Authorization": f"Bearer {config['api_key']}",
        "Content-Type": "application/json",
    }

    http_response = requests.post(config["api_url"], json=payload, headers=headers, timeout=600)
    try:
        response = http_response.json()
    except ValueError as err:
        # Proxies/gateways may answer with an HTML or plain-text error page.
        raise Exception(f"API error: HTTP {http_response.status_code}: {http_response.text[:500]}") from err
    if isinstance(response, dict) and "error" in response:
        raise Exception(response["error"])
    if http_response.status_code != 200 or not isinstance(response, dict):
        raise Exception(f"API error: HTTP {http_response.status_code}: {response}")

    # Some OpenAI-compatible servers omit usage. Treat it as zero, matching
    # the streaming variant, instead of crashing with KeyError.
    usage = response.get("usage") or {}
    completion_tokens = usage.get("completion_tokens") or 0
    prompt_tokens = usage.get("prompt_tokens") or 0
    cost = prompt_tokens * config["cost_input"] + completion_tokens * config["cost_output"]

    response_str = response["choices"][0]["message"]["content"]
    res_json = parse_response(response_str, option1, option2, option3, option4)

    # Highest probability wins; max() keeps the first option on ties.
    options = [option1, option2, option3, option4]
    res = max(options, key=lambda opt: res_json[opt])

    return res, response_str, res_json, completion_tokens, cost

def call_api_stream(prompt, config, option1, option2, option3, option4, gt):
    """Call the streaming chat-completions API (for reasoning models) and pick an answer.

    Args:
        prompt: Text sent as the single user message.
        config: Configuration dict. Reads ``model``, ``temperature``,
            ``api_key``, ``api_url``, ``cost_input`` and ``cost_output``.
        option1, option2, option3, option4: Option labels the model must
            assign probabilities to (see ``parse_response``).
        gt: Not used by this function. Kept for signature parity with
            ``call_api``.

    Returns:
        tuple: ``(res, full_response, res_json, completion_tokens, cost)``.
        - ``res`` is the option with the highest probability; the first
          option wins on ties.
        - ``full_response`` is the streamed reasoning text wrapped in
          ``<thinking>...</thinking>``, a blank line, then the answer text.
        - ``completion_tokens`` and ``cost`` are 0 when the stream reported
          no usage.

    Raises:
        Exception: On a non-200 HTTP status or an ``error`` chunk in the
            stream. Errors from ``parse_response`` propagate unchanged.
    """
    payload = {
        "model": config["model"],
        "messages": [{"role": "user", "content": prompt}],
        "temperature": config["temperature"],
        "stream": True,
        # Ask the server to append a final chunk carrying token usage.
        "stream_options": {
            "include_usage": True
        }
    }
    headers = {
        "Authorization": f"Bearer {config['api_key']}",
        "Content-Type": "application/json"
    }

    # Using the response as a context manager releases the connection even
    # when reading stops early at [DONE] or an exception is raised.
    with requests.post(config["api_url"], json=payload, headers=headers, stream=True, timeout=600) as response:
        if response.status_code != 200:
            try:
                error_info = response.json()
            except ValueError:
                # Error bodies from proxies/gateways are not always JSON.
                error_info = response.text
            if isinstance(error_info, dict):
                error_info = error_info.get('error', error_info)
            raise Exception(f"API error: {error_info}")
        reasoning_content, content, usage_info = _collect_stream(response)

    completion_tokens = 0
    cost = 0
    if usage_info:
        completion_tokens = usage_info.get("completion_tokens") or 0
        prompt_tokens = usage_info.get("prompt_tokens") or 0
        cost = prompt_tokens * config["cost_input"] + completion_tokens * config["cost_output"]

    res_json = parse_response(content, option1, option2, option3, option4)

    # Highest probability wins; max() keeps the first option on ties.
    options = [option1, option2, option3, option4]
    res = max(options, key=lambda opt: res_json[opt])

    full_response = "<thinking>" + reasoning_content + "</thinking>\n\n" + content
    return res, full_response, res_json, completion_tokens, cost


def _collect_stream(response):
    """Accumulate reasoning text, answer text and usage from an SSE chat-completions stream.

    Returns:
        tuple: ``(reasoning_content, content, usage_info)``. ``usage_info`` is
        ``None`` if no chunk carried a non-empty ``usage`` field.

    Raises:
        Exception: If a chunk carries a non-empty ``error`` field.
    """
    reasoning_parts = []
    content_parts = []
    usage_info = None

    for line in response.iter_lines():
        if not line:
            continue
        # SSE data lines look like "data: <payload>"; the space is optional.
        if line.startswith(b"data:"):
            line = line[5:].strip()
        # "[DONE]" is the end-of-stream sentinel, not a JSON payload.
        if line == b"[DONE]":
            break
        try:
            chunk = json.loads(line.decode('utf-8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Skip non-JSON lines such as SSE comments (": keep-alive").
            continue
        if not isinstance(chunk, dict):
            continue
        if chunk.get("error"):
            raise Exception(f"API error: {chunk['error']}")
        # With include_usage, intermediate chunks carry "usage": null and
        # only the final chunk holds the totals. Never overwrite with null.
        if chunk.get("usage"):
            usage_info = chunk["usage"]
        choices = chunk.get("choices") or []
        if choices:
            delta = choices[0].get("delta") or {}
            # A delta may carry both fields, so collect them independently.
            if delta.get("reasoning_content"):
                reasoning_parts.append(delta["reasoning_content"])
            if delta.get("content"):
                content_parts.append(delta["content"])

    return "".join(reasoning_parts), "".join(content_parts), usage_info